"""
Module for ICD10 code preprocessing
"""
__author__ = "shubhranshu-shekhar"
__date__ = "10/14/24"

import bz2
import pickle
import json
import os
import icd10 # pip install icd10-cm
import simple_icd_10_cm as cm # pip install simple-icd-10-cm
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.linalg import block_diag


def icd10_similarity(candicate_codes=None, jaccard=False):
    """Build a block-diagonal pairwise similarity matrix for ICD-10 codes.

    Codes are first grouped by their top-level ancestor (see
    ``get_top_level_groups_icd10``). Similarity is only computed between codes
    in the same group; entries between codes of different groups are 0.

    Args:
        candicate_codes: iterable of ICD-10 code strings. (The misspelled
            parameter name is kept for compatibility with keyword callers.)
        jaccard: must be True; Jaccard similarity over description tokens is
            the only metric implemented.

    Returns:
        Tuple ``(sim, codes)``: ``sim`` is an ``(n, n)`` float array and
        ``codes`` is a length-``n`` array naming the code of each row/column
        of ``sim``. The order follows the grouping, not the input order.

    Raises:
        NotImplementedError: if ``candicate_codes`` is None or ``jaccard`` is
            False.
    """
    if candicate_codes is None:
        raise NotImplementedError("Please provide candidate ICD10 codes to compute similarity")
    # Validate the metric up front instead of after the grouping work.
    if not jaccard:
        raise NotImplementedError("Only jaccard is implemented at this time")

    candicate_codes = list(candicate_codes)
    if not candicate_codes:
        # block_diag() with no blocks returns a (1, 0) array; return a matrix
        # whose shape agrees with the (empty) code list instead.
        return np.zeros((0, 0)), np.array([])

    # Separate the candidates by top-level node of the ICD-10 hierarchy.
    candidate_groups_dct = get_top_level_groups_icd10(candicate_codes)

    blocks = []         # one similarity matrix per group
    ordered_codes = []  # codes in the same order as rows/cols of the result
    for top_node, group_codes in candidate_groups_dct.items():
        if top_node == '999':
            # '999' collects codes the cm library does not know; with no
            # description to compare, each is only similar to itself.
            block = np.eye(len(group_codes))
        else:
            block = get_jaccard_distance_icd10(group_codes)
        blocks.append(block)
        ordered_codes.extend(group_codes)

    return block_diag(*blocks), np.array(ordered_codes)


def jaccard(a, b):
    """Return the Jaccard similarity ``|A & B| / |A | B|`` of two collections.

    Both inputs are converted to sets, so duplicates and order are ignored.
    Two empty inputs are treated as identical and yield 1.0 rather than
    dividing by zero.
    """
    set_a, set_b = set(a), set(b)
    union = set_a | set_b
    if not union:
        return 1.0
    return len(set_a & set_b) / len(union)


def get_top_level_groups_icd10(candicate_codes):
    """Group ICD-10 codes by their top-level ancestor in simple_icd_10_cm.

    Args:
        candicate_codes: iterable of ICD-10 code strings.

    Returns:
        ``defaultdict(list)`` mapping each top-level node (the last element of
        ``cm.get_ancestors(code)``) to the codes under it, in input order.
        Codes rejected by ``cm.is_valid_item`` are collected under ``'999'``.
        A valid code with no ancestors is used as its own group key, which is
        the same key its descendants are grouped under.
    """
    ret_dict = defaultdict(list)

    for cd_ in candicate_codes:
        if not cm.is_valid_item(cd_):
            ret_dict['999'].append(cd_)  # codes that are not in the library
            continue
        ancestors = cm.get_ancestors(cd_)
        # An item with no ancestors (e.g. a chapter) would make [-1] raise
        # IndexError; such a code is itself the top-level node.
        top_node = ancestors[-1] if ancestors else cd_
        ret_dict[top_node].append(cd_)

    return ret_dict


def get_jaccard_distance_icd10(icds):
    """Return the pairwise Jaccard *similarity* matrix for a list of ICD-10 codes.

    Despite the function name, entries are similarities (1 on the diagonal,
    symmetric), not distances. Each code is represented by the whitespace
    tokens of ``get_desc_till_root(code)``.

    Args:
        icds: sequence of ICD-10 codes known to simple_icd_10_cm.

    Returns:
        ``(n, n)`` float ndarray where ``n == len(icds)``.
    """
    n_codes = len(icds)
    # Build each code's token list once; previously it was rebuilt for every
    # pair, costing O(n^2) library lookups.
    tokens = [get_desc_till_root(cd).split() for cd in icds]

    sim = np.eye(n_codes)
    for i in range(n_codes):
        for j in range(i + 1, n_codes):
            sim[i, j] = sim[j, i] = jaccard(tokens[i], tokens[j])
    return sim

def get_desc_till_root(cd):
    """Return the lower-cased, space-joined descriptions of ``cd`` and its ancestors.

    The code's own description comes first, followed by its ancestors' in
    the order ``cm.get_ancestors`` returns them. The last ancestor (the
    top-level node that ``get_top_level_groups_icd10`` groups by) is left
    out. Presumably this is because every code in a group shares it, so it
    would only inflate Jaccard scores.
    """
    parts = [cm.get_description(cd)]
    for ancestor in cm.get_ancestors(cd)[:-1]:
        parts.append(cm.get_description(ancestor))
    return " ".join(parts).lower()


def process_ER_data():
    """Build (or load from cache) the provider feature table for the ER data.

    Each provider's diagnosis vector is multiplied by the precomputed ICD-10
    similarity matrix. The value for code j therefore becomes
    ``sum_i value_i * sim(i, j)``. The procedure columns are then joined on
    unchanged.

    Inputs (paths relative to the working directory):
      * ``../output_ER/2017/provider_icd9_dgns.json`` and
        ``../output_ER/2017/provider_icd9_prcdr.json``. Each is a mapping of
        provider -> {code: value}. Despite "icd9" in the names, they contain
        ICD-10 codes.
      * ``../data/icd10icd10sim.bz2.pkl``: ``[sim_matrix, codes]`` as written
        by ``main()``.

    Diagnosis codes absent from the precomputed code list are dropped.
    ``pd.merge`` defaults to an inner join, so only providers present in both
    JSON files are kept.

    Returns:
        pandas.DataFrame indexed by provider. The result is also cached to
        ``../data/ER/prvdr_icd10_jaccard_df.bz2.pkl`` and returned directly
        from that cache on later calls.
    """
    cache_path = '../data/ER/prvdr_icd10_jaccard_df.bz2.pkl'
    if os.path.exists(cache_path):
        # NOTE: pickle.load can execute arbitrary code; acceptable only
        # because this file is a local cache written by this function.
        with bz2.BZ2File(cache_path, 'rb') as f:
            df_jaccard = pickle.load(f)
        print("Returning preloaded data.")
        return df_jaccard

    # 2017 ER data; the "icd9" files actually contain ICD-10 codes.
    with open('../output_ER/2017/provider_icd9_dgns.json') as f:
        provider_icd9_dgns = json.load(f)
    with open('../output_ER/2017/provider_icd9_prcdr.json') as f:
        provider_icd9_prcdr = json.load(f)

    # One row per provider, one column per diagnosis code.
    df1 = pd.DataFrame(list(provider_icd9_dgns.values()),
                       index=list(provider_icd9_dgns.keys())).fillna(0)
    print("Raw ER data loaded.")

    # Load the precomputed code-code similarity matrix and its code labels.
    with bz2.BZ2File("../data/icd10icd10sim.bz2.pkl", 'rb') as f:
        prior_computed_sim = pickle.load(f)
    icd10icd10sim = prior_computed_sim[0]
    icd10_codes = np.asarray(prior_computed_sim[1])
    print("Precomputed ICD data loaded.")

    # Not every precomputed code appears in the ER data. Keep only codes that
    # are columns of df1, and restrict the similarity matrix to the same rows
    # and columns before multiplying.
    mask_selected_icd10 = np.asarray(pd.Index(icd10_codes).isin(df1.columns), dtype=bool)
    selected_icd10_codes = icd10_codes[mask_selected_icd10]
    selected_icd10icd10sim = icd10icd10sim[np.ix_(mask_selected_icd10, mask_selected_icd10)]

    df1_jaccard = df1[selected_icd10_codes] @ selected_icd10icd10sim
    # The matmul yields positional column labels; restore the code names.
    df1_jaccard.columns = selected_icd10_codes
    print("ICD similarity incorporated loaded.")

    # One row per provider, one column per procedure code.
    df2 = pd.DataFrame(list(provider_icd9_prcdr.values()),
                       index=list(provider_icd9_prcdr.keys())).fillna(0)

    df_jaccard = pd.merge(df1_jaccard, df2, left_index=True, right_index=True)
    print("Data prepared.")

    # Cache for later calls; create the directory so the write cannot fail
    # after all the work above has been done.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with bz2.BZ2File(cache_path, 'wb') as f:
        pickle.dump(df_jaccard, f)
    print("Data saved.")
    return df_jaccard


def main():
    """Build the non-ER provider tables and the ICD-10 similarity matrix.

    Reads ``../output/2017/provider_icd9_{dgns,prcdr}.json``. Despite "icd9"
    in the names, they contain ICD-10 codes. Writes under ``../data/``:
      * ``prvdr_icd10_df.bz2.pkl``: raw diagnosis + procedure columns.
      * ``prvdr_icd10_jaccard_df.bz2.pkl``: diagnosis columns multiplied by
        the similarity matrix, plus the procedure columns.
      * ``icd10icd10sim.bz2.pkl``: ``[sim_matrix, codes]``, which
        ``process_ER_data`` reads.
    """
    with open('../output/2017/provider_icd9_dgns.json') as f:
        provider_icd9_dgns = json.load(f)  # actually contains ICD10 codes
    with open('../output/2017/provider_icd9_prcdr.json') as f:
        provider_icd9_prcdr = json.load(f)  # actually contains ICD10 codes

    # One row per provider: diagnosis columns (df1), procedure columns (df2).
    df1 = pd.DataFrame(list(provider_icd9_dgns.values()),
                       index=list(provider_icd9_dgns.keys())).fillna(0)
    df2 = pd.DataFrame(list(provider_icd9_prcdr.values()),
                       index=list(provider_icd9_prcdr.keys())).fillna(0)

    # Computed first because df1_jaccard below depends on it.
    icd10icd10sim, icd10_codes = icd10_similarity(candicate_codes=list(df1.columns),
                                                  jaccard=True)

    # The similarity matrix rows/cols follow icd10_codes (grouped order), so
    # select df1's columns in that order, multiply, and restore the labels.
    df1_jaccard = df1[icd10_codes] @ icd10icd10sim
    df1_jaccard.columns = icd10_codes

    df = pd.merge(df1, df2, left_index=True, right_index=True)
    df_jaccard = pd.merge(df1_jaccard, df2, left_index=True, right_index=True)

    os.makedirs('../data', exist_ok=True)

    with bz2.BZ2File('../data/prvdr_icd10_df.bz2.pkl', 'wb') as f:
        pickle.dump(df, f)

    with bz2.BZ2File('../data/prvdr_icd10_jaccard_df.bz2.pkl', 'wb') as f:
        pickle.dump(df_jaccard, f)

    with bz2.BZ2File("../data/icd10icd10sim.bz2.pkl", 'wb') as f:
        pickle.dump([icd10icd10sim, icd10_codes], f)


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()